import typing
import copy

from pyodbc import Row, Connection, Cursor, connect

from common.logger import logger
from common.utils import split_array
from common.globalconfig import global_config

# Module-level alias for the shared global configuration object; the DB class
# below reads ``settings.active`` to decide whether to log failing insert data.
settings = global_config


class DB:
    """Convenience wrapper around a pyodbc connection.

    The instance keeps one long-lived connection/cursor pair (``self.conn`` /
    ``self.cursor``). Most methods run on that shared cursor and commit through
    it when asked to. ``update``, ``updateCase`` and ``insert`` with
    ``auto_commit=True`` instead open a short-lived connection, commit there,
    and close it again.

    NOTE(review): table and column names are interpolated directly into the SQL
    text (only values are bound as ``?`` parameters), so they must never come
    from untrusted input.
    """

    def __init__(
            self,
            driver: str,
            host: str,
            database: str,
            user: str,
            password: str,
    ) -> None:
        """Store the connection settings and open the shared connection.

        Raises:
            Any exception raised while connecting (see ``_connect``).
        """
        self.driver = driver
        self.host = host
        self.database = database
        self.user = user
        self.password = password
        self.conn: Connection
        self.cursor: Cursor
        self.conn, self.cursor = self._connect()

    def _connect(self) -> tuple[Connection, Cursor]:
        """Open a new connection with the stored settings.

        Returns:
            A ``(connection, cursor)`` pair.

        Raises:
            Whatever ``connect`` raises. The error is logged and then
            propagated, because returning nothing here would only resurface
            as an opaque unpacking ``TypeError`` at the call site.
        """
        try:
            conn = connect(driver=self.driver, host=self.host, database=self.database, user=self.user,
                           password=self.password)
        except Exception as e:
            logger.error(f'数据库连接异常 {e}')
            raise
        logger.info(f"{self.database}链接成功")
        return conn, conn.cursor()

    def _run_committed(self, work: typing.Callable[[Cursor], None]) -> None:
        """Run ``work(cursor)`` on a fresh connection, commit, and always close it.

        The connection is closed in ``finally``, so a failing statement or
        commit does not leak it.
        """
        conn, cursor = self._connect()
        try:
            work(cursor)
            conn.commit()
        finally:
            self.connect_close(conn, cursor)

    @staticmethod
    def _insert_sql(table: str, columns: list) -> str:
        """Build ``insert into table(c1,c2,...) values (?,?,...)`` for ``columns``."""
        placeholders = ",".join(["?"] * len(columns))
        return f'insert into {table}({",".join(columns)}) values ({placeholders})'

    def _result_columns(self, columns: typing.Optional[list]) -> list:
        """Return ``columns``, or the shared cursor's result-set column names when it is None."""
        if columns is not None:
            return columns
        # DB-API: each description entry is a sequence whose first item is the column name.
        return [desc[0] for desc in self.cursor.description]

    def exec(self, sql, *params, autocommit=True):
        """Execute an arbitrary statement on the shared cursor.

        ``params`` are forwarded to ``Cursor.execute``. The statement is
        committed unless ``autocommit`` is False.
        """
        logger.info(sql)
        self.cursor.execute(sql, *params)
        if autocommit:
            self.cursor.commit()

    def clear(self, table: str, autocommit=True):
        """Empty ``table`` with ``truncate table`` (committed unless ``autocommit`` is False)."""
        self.exec(f'truncate table {table}', autocommit=autocommit)

    def select_all(self, sql, *params) -> list[Row]:
        """Execute a query on the shared cursor and return all rows."""
        logger.info(sql)
        self.cursor.execute(sql, *params)
        return self.cursor.fetchall()

    def select_one(self, sql, *params):
        """Return the first column of the first row (``fetchval``), or None if there are no rows.

        Intended for single-value queries; extra rows and columns are ignored.
        """
        logger.debug(sql)
        self.cursor.execute(sql, *params)
        return self.cursor.fetchval()

    def fetch_one(self, sql, columns: typing.Optional[list] = None, *args) -> typing.Optional[dict]:
        """Execute a query and return its first row as a dict.

        Args:
            sql: query text with ``?`` placeholders.
            columns: dict keys, zipped positionally with the row values. When
                None, the result set's own column names are used.
            *args: parameters forwarded to ``Cursor.execute``.

        Returns:
            ``{column: value}`` for the first row, or None if no row was returned.
        """
        logger.debug(sql)
        self.cursor.execute(sql, *args)
        record = self.cursor.fetchone()
        if record is None:
            return None
        return dict(zip(self._result_columns(columns), record))

    def fetch_all(self, sql, columns: typing.Optional[list] = None, *args) -> typing.List[dict]:
        """Execute a query and return every row as a dict.

        Args:
            sql: query text with ``?`` placeholders.
            columns: dict keys, zipped positionally with each row. When None,
                the result set's own column names are used.
            *args: parameters forwarded to ``Cursor.execute``.
        """
        logger.debug(sql)
        self.cursor.execute(sql, *args)
        record_list = self.cursor.fetchall()
        names = self._result_columns(columns)
        return [dict(zip(names, row)) for row in record_list]

    def select_ones(self, sql, *params) -> set:
        """Return the distinct values of the first column of the result set."""
        logger.debug(sql)
        self.cursor.execute(sql, *params)
        return {row[0] for row in self.cursor.fetchall()}

    def update(self, sql, *params, auto_commit=False):
        """Execute a data-modifying statement.

        With ``auto_commit`` the statement runs on a separate short-lived
        connection and is committed there. Otherwise it runs on the shared
        cursor and is left uncommitted.
        """
        logger.debug(sql)
        if auto_commit:
            self._run_committed(lambda cursor: cursor.execute(sql, *params))
        else:
            self.cursor.execute(sql, *params)

    def updateCase(self, table_name, key, set_key, data, auto_commit=False):
        """Bulk-update one column using a ``case ... end`` expression per chunk.

        Args:
            table_name: table to update.
            key: column identifying the rows.
            set_key: column to assign.
            data: sequence of ``(key_value, new_value)`` pairs; each pair
                fills the two ``?`` of one ``when`` branch.
            auto_commit: when True, run on a separate connection and commit
                after every chunk. Otherwise run on the shared cursor without
                committing.

        Returns:
            False when ``data`` is empty, otherwise True.
        """
        if len(data) == 0:
            return False
        if auto_commit:
            # One short-lived connection for all chunks instead of one per chunk.
            conn, cursor = self._connect()
        else:
            conn, cursor = self.conn, self.cursor
        try:
            # Chunks come from split_array(data, 100), presumably <= 100 pairs each.
            for chunk in split_array(data, 100):
                whens = ' '.join([f' when {key} = ? then ?'] * len(chunk))
                placeholders = ','.join(['?'] * len(chunk))
                sql = f"update {table_name} set {set_key} = case {whens} end  where {key} in ({placeholders})"
                logger.info(sql)
                # Bind every (key, value) pair for the whens, then the keys for the in-list.
                params = [value for pair in chunk for value in pair] + [pair[0] for pair in chunk]
                cursor.execute(sql, params)
                if auto_commit:
                    conn.commit()
        finally:
            if auto_commit:
                self.connect_close(conn, cursor)
        return True

    def get_table_columns(self, table_name: str) -> list:
        """Return the column names of ``table_name`` from ``INFORMATION_SCHEMA.COLUMNS``."""
        self.cursor.execute("""
                            select COLUMN_NAME
                            from INFORMATION_SCHEMA.COLUMNS
                            where TABLE_NAME = ?
                        """, table_name)
        rows = self.cursor.fetchall()
        return [row[0] for row in rows]

    def insert_many(self, table: str, columns: list, data: list[tuple], autocommit=True):
        """Bulk-insert ``data`` with ``fast_executemany`` on the shared cursor.

        Args:
            table: target table.
            columns: column names, in the same order as each tuple in ``data``.
            data: rows to insert. None or empty is a no-op.
            autocommit: commit after the insert when True.

        Raises:
            Any database error from ``executemany``. In the ``dev`` environment
            the offending data is logged before re-raising.
        """
        if not data:
            return
        sql = self._insert_sql(table, columns)
        self.cursor.fast_executemany = True
        logger.info(f"{sql} \n count: {len(data)}")
        try:
            self.cursor.executemany(sql, data)
        except Exception:
            if settings.active == "dev":
                logger.error(data)
            raise
        if autocommit:
            self.cursor.commit()

    def insert(self, table: str, columns: list, data: list[tuple], auto_commit=False):
        """Insert ``data`` (tuples in ``columns`` order) with ``executemany``.

        With ``auto_commit`` the insert runs on a separate short-lived
        connection and is committed there. Otherwise it runs on the shared
        cursor and is left uncommitted.
        """
        sql = self._insert_sql(table, columns)
        if auto_commit:
            self._run_committed(lambda cursor: cursor.executemany(sql, data))
        else:
            self.cursor.executemany(sql, data)

    def insertOrUpdateById(self, table: str, columns: list, data: list[dict], key: str, autocommit=False):
        """Insert rows whose ``key`` is not yet in ``table`` and update the rows that are.

        Args:
            table: target table.
            columns: columns to write; values are read from each dict with ``.get``.
            data: rows as dicts. Empty is a no-op.
            key: identifying column. The update sets every other column in
                ``columns`` ``where key = ?``.
            autocommit: commit once at the end when True.

        NOTE(review): existing ids are matched by comparing each dict's raw
        ``key`` value with the value returned by the database, so both must
        have compatible types (e.g. both int). The lookup binds one parameter
        per row; very large batches may exceed the driver's parameter limit.
        """
        if not data:
            # An empty ``in ()`` list would be a SQL syntax error.
            return
        ids = [str(i.get(key)) for i in data]
        select = f"select {key} from {table} where {key} in ({('?,' * len(ids)).rstrip(',')}) group by {key}"
        existing_ids = {row[0] for row in self.select_all(select, ids)}

        sql = self._insert_sql(table, columns)
        logger.info(sql)
        self.cursor.fast_executemany = True
        save_data = [i for i in data if i.get(key) not in existing_ids]
        update_data = [i for i in data if i.get(key) in existing_ids]
        if save_data:
            self.cursor.executemany(sql, [[i.get(c) for c in columns] for i in save_data])
        update_columns = [c for c in columns if c != key]
        # Update only rows that already existed. With no non-key columns there
        # is nothing to set, and an empty ``set`` clause would be invalid SQL.
        if update_data and update_columns:
            self.updateBatchById(table, update_columns, update_data, key)
        if autocommit:
            self.cursor.commit()

    def insertByKeys(self, table: str, columns: list, data: list[dict], keys: list[str], autocommit=True):
        """Insert only the rows whose ``keys`` combination is not already in ``table``.

        Existing combinations are fetched with one ``select`` that filters each
        key column by the distinct non-None values present in ``data``. A row is
        skipped when its key values, each converted with ``str``, equal those
        of an existing combination.

        Args:
            table: target table.
            columns: columns to insert; values are read from each dict with ``.get``.
            data: candidate rows as dicts. Empty is a no-op.
            keys: columns that together identify a row.
            autocommit: commit after the insert when True.
        """
        if not data:
            return
        self.cursor.fast_executemany = True
        keys_map = {key: {i.get(key) for i in data if i.get(key) is not None} for key in keys}
        key_str = [f"{key} in ({('?,' * len(keys_map.get(key))).rstrip(',')})" for key in keys if
                   len(keys_map.get(key)) > 0]
        select = f"select {','.join(keys)} from {table} {'where' if len(key_str) > 0 else ''} {' and '.join(key_str)} group by {','.join(keys)}"
        # Parameters are produced in the same key order as key_str (empty sets add none).
        rows: list[dict] = self.fetch_all(select, keys, [k for i in keys_map.values() for k in i])

        def signature(record: dict) -> tuple:
            # String form of the key values, so e.g. 1 and "1" compare equal.
            return tuple(str(record.get(key)) for key in keys)

        # Set lookup instead of scanning every existing row for every candidate.
        existing = {signature(row) for row in rows}
        save_data = [i for i in data if signature(i) not in existing]
        sql = self._insert_sql(table, columns)
        logger.info(f"{sql} \n 本次插入{len(save_data)}条")
        if save_data:
            self.cursor.executemany(sql, [[i.get(c) for c in columns] for i in save_data])
        if autocommit:
            self.cursor.commit()

    def updateBatchById(self, table, columns, data: list[dict], key, autocommit=False):
        """Update ``columns`` of each row in ``data``, matched by ``key``, via ``executemany``.

        Returns:
            False when ``data`` is empty, otherwise None.
        """
        if not data:
            return False
        sql = f"update {table} set {','.join([f' {i} = ? ' for i in columns])} where {key} = ?"
        logger.info(sql)
        self.cursor.fast_executemany = True
        _data = [[i.get(c) for c in columns] + [i.get(key)] for i in data]
        self.cursor.executemany(sql, _data)
        if autocommit:
            self.cursor.commit()

    def commit(self):
        """Commit the shared connection."""
        self.conn.commit()
        logger.info("commit")

    def close(self):
        """Close the shared connection."""
        self.conn.close()
        logger.info("close")

    def connect_close(self, conn, cursor):
        """Close ``cursor`` and then ``conn`` (used for short-lived connections)."""
        cursor.close()
        conn.close()

    def rollback(self):
        """Roll back the shared connection."""
        self.conn.rollback()
        logger.info("rollback")

